{
 "cells": [
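  {
   "cell_type": "markdown",
   "id": "a3d54dd1",
   "metadata": {},
   "source": [
    "# Export FFNet models to ONNX and TensorFlow Lite\n",
    "\n",
    "This notebook clones the FFNet repository, downloads the released weights, and exports each segmentation model to ONNX and TensorFlow Lite with `onnx2tf` and `tf2onnx`.\n",
    "\n",
    "## References\n",
    "- https://github.com/Qualcomm-AI-research/FFNet\n",
    "- https://github.com/PINTO0309/onnx2tf\n",
    "- https://github.com/onnx/tensorflow-onnx"
   ]
  },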
  {
   "cell_type": "markdown",
   "id": "39d96932",
   "metadata": {},
   "source": [
    "## Clone the repository. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13dbc645",
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://github.com/Qualcomm-AI-research/FFNet.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "820a76a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd FFNet"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5a4d9f3",
   "metadata": {},
   "source": [
    "Patch `config.py` so that `model_weights_base_path` points to `./model_weights/`, the directory the trained weights are downloaded to below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "329a1eee",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "# Store the patch through a quoted heredoc: the text is written verbatim and the\n",
    "# cell no longer relies on cat draining the rest of the script from bash's stdin.\n",
    "cat > ./config.patch << 'EOF'\n",
    "diff --git a/config.py b/config.py\n",
    "index a9cf687..30efc27 100644\n",
    "--- a/config.py\n",
    "+++ b/config.py\n",
    "@@ -3,7 +3,7 @@\n",
    " \n",
    " imagenet_base_path = \"/workspace/imagenet/\"\n",
    " cityscapes_base_path = \"/workspace/cityscapes/\"\n",
    "-model_weights_base_path = \"/workspace/ffnet_weights/\"\n",
    "+model_weights_base_path = \"./model_weights/\"\n",
    " \n",
    " CITYSCAPES_MEAN = [0.485, 0.456, 0.406]\n",
    " CITYSCAPES_STD = [0.229, 0.224, 0.225]\n",
    "EOF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d80645f",
   "metadata": {},
   "outputs": [],
   "source": [
    "!patch < ./config.patch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "994a3cc7",
   "metadata": {},
   "source": [
    "## Download the trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a41027fe-24e8-4530-a233-3756f1502825",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import glob\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from models.model_registry import model_entrypoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f800287e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Archive file names; each one is appended to the release download URL below.\n",
    "MODELS = [\n",
    "    base_name + \".zip\"\n",
    "    for base_name in (\n",
    "        \"ffnet101\",\n",
    "        \"ffnet122N\",\n",
    "        \"ffnet122NS\",\n",
    "        \"ffnet134\",\n",
    "        \"ffnet150\",\n",
    "        \"ffnet150S\",\n",
    "        \"ffnet18\",\n",
    "        \"ffnet34\",\n",
    "        \"ffnet40S\",\n",
    "        \"ffnet46N\",\n",
    "        \"ffnet46NS\",\n",
    "        \"ffnet50\",\n",
    "        \"ffnet54S\",\n",
    "        \"ffnet56\",\n",
    "        \"ffnet74N\",\n",
    "        \"ffnet74NS\",\n",
    "        \"ffnet78S\",\n",
    "        \"ffnet86\",\n",
    "        \"ffnet86S\",\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d4bbfcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "download_path = os.path.join(\".\", \"model_weights\")\n",
    "release_url = \"https://github.com/Qualcomm-AI-research/FFNet/releases/download/models/\"\n",
    "\n",
    "for model in MODELS:\n",
    "    download_url = release_url + model\n",
    "    # Derive the archive path from download_path so it always matches the wget target directory.\n",
    "    file_path = os.path.join(download_path, model)\n",
    "\n",
    "    # -nc: keep an archive that is already present instead of saving a numbered duplicate.\n",
    "    !wget -nc {download_url} -P {download_path}\n",
    "    # -o: overwrite extracted files without prompting, so re-running the cell never waits for input.\n",
    "    !unzip -o {file_path} -d {download_path}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4626543a",
   "metadata": {},
   "source": [
    "## Export ONNX and TensorFlow Lite model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10c00758",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (model name, (height, width)) pairs; the export loop unpacks each entry and uses\n",
    "# the size for the dummy input. Kept as a list (not a set) so the export order is\n",
    "# the same on every run, and each model is listed exactly once.\n",
    "# NOTE(review): the original literal repeated the two ffnet40S_BBB_mobile entries\n",
    "# right after ffnet150S_BCC_mobile; possibly segmentation_ffnet150S_BCC_mobile_pre_down\n",
    "# was intended there. Confirm against the model registry before adding it.\n",
    "SEG_MODEL_NAME = [\n",
    "    (\"segmentation_ffnet101_AAA\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet50_AAA\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet150_AAA\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet134_AAA\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet86_AAA\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet56_AAA\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet34_AAA\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet150_ABB\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet86_ABB\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet56_ABB\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet34_ABB\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet150S_BBB\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet86S_BBB\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet122N_CBB\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet74N_CBB\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet46N_CBB\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet101_dAAA\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet50_dAAA\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet150_dAAA\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet134_dAAA\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet86_dAAA\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet56_dAAA\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet34_dAAA\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet18_dAAA\", (1024, 2048)),\n",
    "    # (\"segmentation_ffnet150_dAAC\", (1024, 2048)), size mismatch\n",
    "    # (\"segmentation_ffnet86_dAAC\", (512, 1024)), size mismatch\n",
    "    # (\"segmentation_ffnet34_dAAC\", (512, 1024)), size mismatch\n",
    "    # (\"segmentation_ffnet18_dAAC\", (512, 1024)), size mismatch\n",
    "    (\"segmentation_ffnet150S_dBBB\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet86S_dBBB\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet86S_dBBB_mobile\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet78S_dBBB_mobile\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet54S_dBBB_mobile\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet40S_dBBB_mobile\", (1024, 2048)),\n",
    "    (\"segmentation_ffnet150S_BBB_mobile\", (512, 1024)),\n",
    "    (\"segmentation_ffnet150S_BBB_mobile_pre_down\", (512, 1024)),\n",
    "    (\"segmentation_ffnet86S_BBB_mobile\", (512, 1024)),\n",
    "    (\"segmentation_ffnet86S_BBB_mobile_pre_down\", (512, 1024)),\n",
    "    (\"segmentation_ffnet78S_BBB_mobile\", (512, 1024)),\n",
    "    (\"segmentation_ffnet78S_BBB_mobile_pre_down\", (512, 1024)),\n",
    "    (\"segmentation_ffnet54S_BBB_mobile\", (512, 1024)),\n",
    "    (\"segmentation_ffnet54S_BBB_mobile_pre_down\", (512, 1024)),\n",
    "    (\"segmentation_ffnet40S_BBB_mobile\", (512, 1024)),\n",
    "    (\"segmentation_ffnet40S_BBB_mobile_pre_down\", (512, 1024)),\n",
    "    (\"segmentation_ffnet150S_BCC_mobile\", (512, 1024)),\n",
    "    (\"segmentation_ffnet86S_BCC_mobile_pre_down\", (512, 1024)),\n",
    "    (\"segmentation_ffnet86S_BCC_mobile\", (512, 1024)),\n",
    "    (\"segmentation_ffnet78S_BCC_mobile_pre_down\", (512, 1024)),\n",
    "    (\"segmentation_ffnet78S_BCC_mobile\", (512, 1024)),\n",
    "    (\"segmentation_ffnet54S_BCC_mobile_pre_down\", (512, 1024)),\n",
    "    (\"segmentation_ffnet54S_BCC_mobile\", (512, 1024)),\n",
    "    (\"segmentation_ffnet40S_BCC_mobile_pre_down\", (512, 1024)),\n",
    "    (\"segmentation_ffnet40S_BCC_mobile\", (512, 1024)),\n",
    "    (\"segmentation_ffnet122NS_CBB_mobile_pre_down\", (512, 1024)),\n",
    "    (\"segmentation_ffnet122NS_CBB_mobile\", (512, 1024)),\n",
    "    (\"segmentation_ffnet74NS_CBB_mobile_pre_down\", (512, 1024)),\n",
    "    (\"segmentation_ffnet74NS_CBB_mobile\", (512, 1024)),\n",
    "    (\"segmentation_ffnet46NS_CBB_mobile_pre_down\", (512, 1024)),\n",
    "    (\"segmentation_ffnet46NS_CBB_mobile\", (512, 1024)),\n",
    "    (\"segmentation_ffnet122NS_CCC_mobile_pre_down\", (512, 1024)),\n",
    "    (\"segmentation_ffnet122NS_CCC_mobile\", (512, 1024)),\n",
    "    # (\"segmentation_ffnet74NS_CCC_mobile_pre_down\", (512, 1024)),  RuntimeError: Error(s) in loading state_dict for FFNet:\n",
    "    (\"segmentation_ffnet74NS_CCC_mobile\", (512, 1024)),\n",
    "    (\"segmentation_ffnet46NS_CCC_mobile_pre_down\", (512, 1024)),\n",
    "    (\"segmentation_ffnet46NS_CCC_mobile\", (512, 1024)),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ad298a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExportFFNet(torch.nn.Module):\n",
    "    \"\"\"Wrap an FFNet model so the exported graph ends with a resize and an argmax.\n",
    "\n",
    "    Args:\n",
    "        model_name: Name passed to ``model_entrypoint`` to build the wrapped model.\n",
    "        output_size: ``(height, width)`` the model output is bilinearly resized to\n",
    "            before the argmax. Defaults to ``(1024, 2048)``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model_name, output_size=(1024, 2048)):\n",
    "        super().__init__()\n",
    "        self.model = model_entrypoint(model_name)()\n",
    "        self.output_size = tuple(output_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the model, resize its output to ``output_size`` and return the argmax over dim 1.\"\"\"\n",
    "        x = self.model(x)\n",
    "        x = F.interpolate(x, self.output_size, mode=\"bilinear\", align_corners=True)\n",
    "        # dim=1 is presumably the class/channel axis of the segmentation logits -- TODO confirm.\n",
    "        x = torch.argmax(x, dim=1)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdf93e5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_path = os.path.join(\".\", \"ffnet_models\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ef1b6b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from an empty output directory.\n",
    "if os.path.exists(output_path):\n",
    "    shutil.rmtree(output_path)\n",
    "os.mkdir(output_path)\n",
    "\n",
    "# Remove leftovers in the working directory: the saved_model folder and any *.onnx files.\n",
    "saved_model_dir = os.path.join(\".\", \"saved_model\")\n",
    "if os.path.exists(saved_model_dir):\n",
    "    shutil.rmtree(saved_model_dir)\n",
    "\n",
    "for stale_onnx in glob.glob(os.path.join(\".\", \"*.onnx\")):\n",
    "    os.remove(stale_onnx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e94e3884",
   "metadata": {},
   "outputs": [],
   "source": [
    "saved_model_dir = os.path.join(\".\", \"saved_model\")\n",
    "\n",
    "\n",
    "def _collect_tflite(src_dir, dst_dir):\n",
    "    \"\"\"Copy every .tflite file found directly in src_dir into dst_dir, then delete src_dir.\"\"\"\n",
    "    for tflite_file in glob.glob(os.path.join(src_dir, \"*.tflite\")):\n",
    "        shutil.copyfile(tflite_file, os.path.join(dst_dir, os.path.basename(tflite_file)))\n",
    "    # If the conversion produced nothing, src_dir does not exist; removing it\n",
    "    # unconditionally would raise and abort the export of all remaining models.\n",
    "    if os.path.isdir(src_dir):\n",
    "        shutil.rmtree(src_dir)\n",
    "\n",
    "\n",
    "for model_name, size in SEG_MODEL_NAME:\n",
    "    height, width = size\n",
    "\n",
    "    print(\"----- Start {} -----\".format(model_name))\n",
    "\n",
    "    # Skip models whose fused-argmax ONNX file is already in the output directory.\n",
    "    fused_output_onnx_path = os.path.join(output_path, model_name + \"_fused_argmax.onnx\")\n",
    "    if os.path.exists(fused_output_onnx_path):\n",
    "        print(\"{} has already been exported.\".format(model_name))\n",
    "        continue\n",
    "\n",
    "    # Load the model and export it to a temporary ONNX file in the working directory.\n",
    "    model = ExportFFNet(model_name=model_name)\n",
    "    dummy_input = torch.randn(1, 3, height, width, device=\"cpu\")\n",
    "    tmp_onnx_path = os.path.join(\".\", model_name + \".onnx\")\n",
    "    torch.onnx.export(\n",
    "        model,\n",
    "        dummy_input,\n",
    "        tmp_onnx_path,\n",
    "        verbose=False,\n",
    "        input_names=[\"input1\"],\n",
    "        output_names=[\"output1\"],\n",
    "    )\n",
    "\n",
    "    # Pass 1: convert with the default argmax.\n",
    "    output_onnx_path = os.path.join(output_path, model_name + \".onnx\")\n",
    "    tflite_float32_path = os.path.join(saved_model_dir, model_name + \"_float32.tflite\")\n",
    "\n",
    "    !onnx2tf -i {tmp_onnx_path}\n",
    "    !python -m tf2onnx.convert --tflite {tflite_float32_path} --inputs-as-nchw inputs_0 --output {output_onnx_path}\n",
    "    _collect_tflite(saved_model_dir, output_path)\n",
    "\n",
    "    # Pass 2: convert with fused argmax (-rafi64). The temporary ONNX file is renamed\n",
    "    # first, presumably so the generated files carry the _fused_argmax suffix that\n",
    "    # the tflite path below expects.\n",
    "    fused_onnx_path = os.path.join(\".\", model_name + \"_fused_argmax.onnx\")\n",
    "    os.rename(tmp_onnx_path, fused_onnx_path)\n",
    "    fused_tflite_float32_path = os.path.join(saved_model_dir, model_name + \"_fused_argmax_float32.tflite\")\n",
    "\n",
    "    !onnx2tf -i {fused_onnx_path} -rafi64\n",
    "    !python -m tf2onnx.convert --tflite {fused_tflite_float32_path} --inputs-as-nchw inputs_0 --output {fused_output_onnx_path}\n",
    "    _collect_tflite(saved_model_dir, output_path)\n",
    "    os.remove(fused_onnx_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da8711bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "!tar czf ffnet_models.tar.gz ffnet_models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bea3b409",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
